import argparse
import numpy as np

from dada.model.log_sum_exp.log_sum_exp_runner import LogSumExpRunner
from dada.model.model_enum import Model
from dada.model.norm.norm_runner import NormRunner
from dada.model.polynomial_feasible.poly_feasible_runner import PolynomialFeasibleRunner
from dada.model.worst_instance.worst_instance_runner import WorstCaseRunner

def find_model_runner(model: str, params: dict):
    """Instantiate the runner class that matches a model identifier.

    Args:
        model: Model identifier, compared against the ``value`` of the
            ``Model`` enum members (``run()`` passes the raw CLI string).
        params: Parameter dict forwarded unchanged to the runner's constructor.

    Returns:
        A new ``NormRunner``, ``LogSumExpRunner``,
        ``PolynomialFeasibleRunner`` or ``WorstCaseRunner`` instance.

    Raises:
        ValueError: If ``model`` does not match any supported model.
    """
    # Single source of truth: the error message below is derived from this
    # table, so it cannot drift out of sync when a model is added or removed.
    runners = {
        Model.NORM.value: NormRunner,
        Model.LOG_SUM_EXP.value: LogSumExpRunner,
        Model.POLYNOMIAL_FEASIBILITY.value: PolynomialFeasibleRunner,
        Model.WORST_INSTANCE.value: WorstCaseRunner,
    }
    try:
        runner_cls = runners[model]
    except (KeyError, TypeError):
        # TypeError covers unhashable input, so callers always get ValueError.
        raise ValueError(
            f"model should be one of these options: {sorted(runners, key=str)}, "
            f"got {model!r}"
        ) from None
    return runner_cls(params)


def _build_parser():
    """Build the command-line argument parser used by ``run``."""
    parser = argparse.ArgumentParser(prog="Adaptive Methods Runner")
    parser.add_argument(
        'model',
        help='Model to run. Specify "LSE" for log-sum-exp, "NORM" for norm, '
             '"PF" for polynomial feasible and "WI" for worst instance model.')
    parser.add_argument('-d', '--dim', type=int, help='Dimension')
    parser.add_argument('-r', '--radius', type=float, help='Radius')
    parser.add_argument('-n', '--num-polyhedron', type=int, help='Number of polyhedrons')
    parser.add_argument('-s', '--steps', type=int, help='Number of iterations', required=True)
    parser.add_argument('--plot', action='store_true', help='Save plots')
    parser.add_argument('--plot-dir', type=str, default='.',
                        help="Path to the directory to save the plots.")
    parser.add_argument('--model-name', type=str,
                        help='Name passed to the runner for this run; defaults to MODEL.')
    parser.add_argument('--seed', type=int, default=1, help='Random seed')

    parser.add_argument('-ml', '--mu-list', nargs='+', type=float,
                        help='One or more mu values (forwarded as params["mu_list"]).')

    parser.add_argument('-ql', '--q-list', nargs='+', type=float,
                        help='One or more q values (forwarded as params["q_list"]).')

    parser.add_argument('-pl', '--p-list', nargs='+', type=float,
                        help='One or more p values (forwarded as params["p_list"]).')
    return parser


def run(argv=None):
    """Parse command-line arguments, seed NumPy and execute the selected model.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, which is the original CLI
            behaviour.

    Raises:
        ValueError: Propagated from ``find_model_runner`` for an unknown model.
    """
    args = _build_parser().parse_args(argv)

    model = args.model
    # Fall back to the model identifier when no explicit run name is given.
    model_name = args.model_name if args.model_name is not None else model

    # Seeds NumPy's global RNG, presumably so runs are reproducible.
    np.random.seed(args.seed)

    params = {
        'vector_size': args.dim,
        'radius': args.radius,
        'num_polyhedron': args.num_polyhedron,
        'mu_list': args.mu_list,
        'q_list': args.q_list,
        'p_list': args.p_list,
    }

    model_runner = find_model_runner(model, params)
    model_runner.run(args.steps, model_name, args.plot, args.plot_dir)
